#include "nccl_net.h"
#include "transport.h"
#include "transport_config.h"
#include "util_rdma.h"
#include <glog/logging.h>
#include <atomic>
#include <mutex>
#include <thread>
#include <unistd.h>
using namespace uccl;

// Plugin name exposed to NCCL through ncclNetPlugin_v8.name.
char const* PLUGIN_NAME = "RDMA_Plugin";

// Raised by interrupt_handler(); nothing in the code shown here reads it.
bool volatile quit = false;

/// Signal handler that only raises the global `quit` flag. The signal number
/// is deliberately ignored.
void interrupt_handler(int signal) {
  static_cast<void>(signal);
  quit = true;
}

/// Fixed-capacity pool of ucclRequest-sized slots. Each send/recv comm owns one
/// (created in pluginConnect()/pluginAccept()); isend/irecv/iflush take a slot
/// from it and pluginTest() gives the slot back.
class ucclRequestBuffPool : public BuffPool {
  // Slot count: kMaxReq << 2, sized to cover both send and receive requests.
  static constexpr size_t kNumSlots = kMaxReq << 2;
  // Each slot holds exactly one ucclRequest.
  static constexpr size_t kSlotBytes = sizeof(ucclRequest);

 public:
  ucclRequestBuffPool() : BuffPool(kNumSlots, kSlotBytes, nullptr) {}
  ~ucclRequestBuffPool() = default;
};

// Process-wide RDMA endpoint. Created in pluginInit() and used by every plugin
// entry point below (listen, connect, accept, regMr, isend, irecv, iflush,
// test, ...).
std::shared_ptr<RDMAEndpoint> ep;

// Progress of the non-blocking connect/accept handshake. pluginConnect() and
// pluginAccept() move from kConnInit to kConnConnecting when they spawn their
// helper thread. That thread stores kConnConnected after it has filled in the
// connection fields.
enum ConnState { kConnInit = 0, kConnConnecting, kConnConnected };

// State shared by send and receive comms. It is the first member of both
// ucclSendComm and ucclRecvComm.
struct ucclBaseComm {
  // Local device index the connection was established on.
  int dev;
  // Connection returned by ep->uccl_connect() / ep->uccl_accept(). Its
  // `context` field is passed to the endpoint as a UcclFlow*.
  ConnID conn_id;
  // Per-comm pool of ucclRequest slots used by isend/irecv/iflush.
  std::shared_ptr<ucclRequestBuffPool> uccl_req_pool;
};

// Results produced by the accept helper thread in pluginAccept(). They are
// copied into a ucclRecvComm once the listen comm reaches kConnConnected.
struct AsyncAcceptState {
  struct ucclBaseComm base;
  // Peer address and device index reported by ep->uccl_accept().
  std::string remote_ip_str;
  int remote_dev;
};

// Result produced by the connect helper thread in pluginConnect(). It is
// copied into a ucclSendComm once the handle reaches kConnConnected.
struct AsyncConnectState {
  struct ucclBaseComm base;
};

// Handle generated by pluginListen() and then transferred to the remote side.
// The remote side calls pluginConnect() with this handle.
//
// The storage is the opaque handle buffer NCCL passes to pluginListen(), which
// memset()s it to zero before filling in the fields. It must therefore fit in
// NCCL_NET_HANDLE_MAXSIZE, which the static_assert below checks.
// NOTE(review): connect_buffer embeds a std::shared_ptr (via ucclBaseComm), but
// the struct is zeroed with memset() and never constructed. This relies on an
// all-zero shared_ptr behaving as an empty one — confirm this is acceptable.
struct ucclHandle {
  // IPv4 address of the listener's NIC, encoded by str_to_ip().
  uint32_t ip_addr_u32;
  // Listener's OS-assigned TCP port, in host byte order.
  uint16_t listen_port;
  // Device index the listener was created on.
  int remote_dev;
  // GPU index of the listening process (cudaGetDevice/hipGetDevice).
  int remote_gpuidx;
  // Connect-side progress, driven by pluginConnect(). Zero == kConnInit.
  enum ConnState state = kConnInit;
  // Filled in by the helper thread that pluginConnect() spawns.
  AsyncConnectState connect_buffer;
};
static_assert(sizeof(struct ucclHandle) < NCCL_NET_HANDLE_MAXSIZE,
              "ucclHandle size too large");

// Handle generated by pluginListen() for pluginAccept() to use. It is
// allocated with calloc() in pluginListen() and released by
// pluginCloseListen(). calloc() does not apply the default member initializer
// on `state`, but the zero fill equals kConnInit.
struct ucclListenComm {
  // Local device index passed to pluginListen().
  int dev;
  // Listening TCP socket, handed to ep->uccl_accept() by the accept thread.
  int listen_fd;
  // NOTE(review): not assigned in the code shown (accept results go to
  // accept_buffer.remote_dev instead) — confirm whether it is still needed.
  int remote_dev;
  // GPU index of the listening process (cudaGetDevice/hipGetDevice).
  int gpuidx;
  // Accept-side progress of the non-blocking handshake.
  enum ConnState state = kConnInit;
  // Filled in by the helper thread that pluginAccept() spawns.
  AsyncAcceptState accept_buffer;
};

// Handle generated by pluginAccept(). It is allocated with calloc() there and
// released by pluginCloseRecv().
struct ucclRecvComm {
  // Keep as the first member: pluginRegMr(), pluginRegMrDmaBuf() and
  // pluginDeregMr() cast their opaque comm argument straight to ucclBaseComm*.
  struct ucclBaseComm base;
  // Peer address reported by ep->uccl_accept() (only logged here).
  std::string remote_ip_str;
  // Peer device index reported by ep->uccl_accept().
  int remote_dev;
};

// Handle generated by pluginConnect(). It is allocated with calloc() there and
// released by pluginCloseSend().
struct ucclSendComm {
  // Keep as the first member (see the note on ucclRecvComm::base).
  struct ucclBaseComm base;
};

/// Plugin initialization: prints a greeting and creates the process-wide
/// RDMAEndpoint `ep` with ucclParamNUM_ENGINES() engines.
///
/// @param logFunction NCCL's logger callback. Unused: this plugin logs through
///                    glog / UCCL_LOG_PLUGIN instead.
/// @return Always ncclSuccess.
ncclResult_t pluginInit(ncclDebugLogger_t logFunction) {
  static_cast<void>(logFunction);

  pid_t const self_pid = getpid();
  std::cout << "Hello UCCL from PID: " << self_pid << std::endl;

  auto const num_engines = ucclParamNUM_ENGINES();
  ep = std::make_shared<RDMAEndpoint>(num_engines);
  return ncclSuccess;
}

/// Reports the number of RDMA devices known to the endpoint `ep`.
/// @param ndev Output: device count.
/// @return Always ncclSuccess.
ncclResult_t pluginDevices(int* ndev) {
  int const device_count = ep->get_num_devices();
  *ndev = device_count;
  return ncclSuccess;
}

/// @ref ncclIbGetPciPath
/// Resolves `/sys/class/infiniband/<ib_name>/device` to its canonical path.
///
/// @param ib_name IB device name (e.g. as stored in the factory device).
/// @param path    Output. On success, receives a realpath()-allocated string
///                that the caller owns. Left untouched on failure.
/// @return ncclSuccess, or ncclInternalError if the path cannot be resolved.
ncclResult_t pluginPciPath(char const* ib_name, char** path) {
  char sysfs_link[256];
  snprintf(sysfs_link, sizeof(sysfs_link), "/sys/class/infiniband/%s/device",
           ib_name);

  char* resolved = realpath(sysfs_link, nullptr);
  if (!resolved) {
    LOG(ERROR) << "Could not find device path for " << ib_name;
    return ncclInternalError;
  }

  *path = resolved;
  return ncclSuccess;
}

#define MAX_STR_LEN 255
/// Reads the contents of `<path>/<fileName>` (e.g. a sysfs or procfs entry)
/// into `strValue`.
///
/// @param path      Directory containing the file.
/// @param fileName  File name inside `path`.
/// @param strValue  Output buffer of at least MAX_STR_LEN bytes. On return it
///                  is always NUL-terminated. It is empty if the file could
///                  not be opened or read. Otherwise it holds the contents
///                  with one trailing newline removed. Contents that fill the
///                  whole buffer lose their last byte to the terminator.
/// @return Always ncclSuccess. An unreadable file is logged and ignored.
ncclResult_t ncclTopoGetStrFromSys(char const* path, char const* fileName,
                                   char* strValue) {
  char filePath[PATH_MAX];
  // snprintf keeps long path/fileName combinations inside filePath.
  snprintf(filePath, sizeof(filePath), "%s/%s", path, fileName);
  int offset = 0;
  FILE* file;
  if ((file = fopen(filePath, "r")) != NULL) {
    while (feof(file) == 0 && ferror(file) == 0 && offset < MAX_STR_LEN) {
      size_t len = fread(strValue + offset, 1, MAX_STR_LEN - offset, file);
      offset += static_cast<int>(len);
    }
    fclose(file);
  }
  if (offset == 0) {
    strValue[0] = '\0';
    UCCL_LOG_PLUGIN << Format(
        "Topology detection : could not read %s, ignoring", filePath);
  } else if (strValue[offset - 1] == '\n' || offset == MAX_STR_LEN) {
    // Drop the trailing newline. If the buffer is completely full, overwrite
    // the last byte instead so the terminator stays within MAX_STR_LEN bytes.
    strValue[offset - 1] = '\0';
  } else {
    // No trailing newline: keep every byte that was read.
    strValue[offset] = '\0';
  }
  return ncclSuccess;
}

// Tunable read through ucclParamIbPciRelaxedOrdering() by the Hyper-V check in
// ibGdrSupportInitOnce() (HIP builds only). There, a value of 0 combined with
// /proc/sys/kernel/numa_balancing == "1" turns GDR off. The third macro
// argument (2) is presumably the default used when the IB_PCI_RELAXED_ORDERING
// setting is not provided — confirm against UCCL_PARAM's definition.
UCCL_PARAM(IbPciRelaxedOrdering, "IB_PCI_RELAXED_ORDERING", 2);

// GPUDirect RDMA (GDR) detection. ibGdrSupportInitOnce() runs once per process
// (via pthread_once in ncclIbGdrSupport()). It caches in ncclIbGdrModuleLoaded
// whether the kernel-side peer-memory support needed for GDR is present.
// ncclIbGdrSupport() then returns:
//   ncclSuccess     : GDR works
//   ncclSystemError : no module, or module loaded but not supported by GPU
// KNL_MODULE_LOADED(path) evaluates to 1 if access(path, F_OK) succeeds, else 0.
#define KNL_MODULE_LOADED(a) ((access(a, F_OK) == -1) ? 0 : 1)
static int ncclIbGdrModuleLoaded = 0;  // 1 = true, 0 = false

/// Probes peer-memory kernel support and stores the result (0/1) in
/// ncclIbGdrModuleLoaded.
///  - ROCm/HIP builds work in three steps:
///    1. Look for an amdkfd `memory_peers` version file in several sysfs
///       locations.
///    2. On Hyper-V hosts, force the result to 0 when numa_balancing is "1"
///       and IB_PCI_RELAXED_ORDERING is 0.
///    3. If the result is still 0, scan /proc/kallsyms for
///       ib_register_peer_memory_client.
///    NOTE(review): step 3 can re-enable GDR after the step 2 override —
///    confirm that this is intended.
///  - CUDA builds: check for the nv_mem, nv_mem_nc or nvidia_peermem module
///    version files.
static void ibGdrSupportInitOnce() {
#if defined(__HIP_PLATFORM_AMD__) || defined(__HIPCC__)
  if (ncclIbGdrModuleLoaded == 0) {
    // Check for `memory_peers` directory containing `amdkfd/version`
    // This `memory_peers` directory is created by NIC-GPU driver interaction
    // On Linux kernel 5.15.0 (e.g. Ubuntu 22.04), `memory_peers` is created
    // under `/sys/kernel/mm/` However, on newer kernels like Ubuntu 24.04.1
    // (Linux kernel 6.8.0) or Ubuntu 22.04.4 HWE (Linux kernel 6.5.0), this
    // `memory_peers` directory is either not created (go to else-if condition)
    // or created under a different path like `/sys/kernel/` or `/sys/`
    // (depending on your ib_peer_mem module)
    char const* memory_peers_paths[] = {
        "/sys/kernel/mm/memory_peers/amdkfd/version",
        "/sys/kernel/memory_peers/amdkfd/version",
        "/sys/memory_peers/amdkfd/version", NULL};
    int i = 0;

    while (memory_peers_paths[i]) {
      if (access(memory_peers_paths[i], F_OK) == 0) {
        ncclIbGdrModuleLoaded = 1;
        UCCL_LOG_PLUGIN << Format("Found %s", memory_peers_paths[i]);
        break;
      } else {
        ncclIbGdrModuleLoaded = 0;
      }
      ++i;
    }

    char strValue[MAX_STR_LEN];
    ncclTopoGetStrFromSys("/sys/devices/virtual/dmi/id", "bios_version",
                          strValue);
    if (strncmp("Hyper-V UEFI Release", strValue, 20) == 0) {
      int roMode = ucclParamIbPciRelaxedOrdering();
      ncclTopoGetStrFromSys("/proc/sys/kernel", "numa_balancing", strValue);
      if (strcmp(strValue, "1") == 0 && roMode == 0) ncclIbGdrModuleLoaded = 0;
    }

    if (ncclIbGdrModuleLoaded == 0) {
      // Check for `ib_register_peer_memory_client` symbol in `/proc/kallsyms`
      // if your system uses native OS ib_peer module
      char buf[256];
      FILE* fp = fopen("/proc/kallsyms", "r");

      if (fp == NULL) {
        UCCL_LOG_PLUGIN << "Could not open /proc/kallsyms";
      } else {
        while (fgets(buf, sizeof(buf), fp) != NULL) {
          if (strstr(buf, "t ib_register_peer_memory_client") != NULL ||
              strstr(buf, "T ib_register_peer_memory_client") != NULL) {
            ncclIbGdrModuleLoaded = 1;
            UCCL_LOG_PLUGIN
                << "Found ib_register_peer_memory_client in /proc/kallsyms";
            break;
          }
        }
        // Release the stream on every path; it was previously leaked.
        fclose(fp);
      }
    }
  }
#else
  // Check for the nv_peer_mem module being loaded
  ncclIbGdrModuleLoaded =
      KNL_MODULE_LOADED("/sys/kernel/mm/memory_peers/nv_mem/version") ||
      KNL_MODULE_LOADED("/sys/kernel/mm/memory_peers/nv_mem_nc/version") ||
      KNL_MODULE_LOADED("/sys/module/nvidia_peermem/version");
#endif
}

/// Reports whether GPUDirect RDMA is usable on this host. The probe in
/// ibGdrSupportInitOnce() runs at most once per process (pthread_once); later
/// calls only read the cached ncclIbGdrModuleLoaded flag.
/// @return ncclSuccess if peer-memory support was detected, otherwise
///         ncclSystemError.
ncclResult_t ncclIbGdrSupport() {
  static pthread_once_t gdr_probe_once = PTHREAD_ONCE_INIT;
  pthread_once(&gdr_probe_once, ibGdrSupportInitOnce);
  return ncclIbGdrModuleLoaded ? ncclSuccess : ncclSystemError;
}

/// Fills `props` with the properties of RDMA device `dev` (NCCL v8
/// getProperties).
///
/// @param dev   Device index.
/// @param props Output properties. `pciPath` receives a realpath()-allocated
///              string.
/// @return ncclSuccess, or the error from pluginPciPath() when the device's
///         sysfs path cannot be resolved. Returning early avoids handing NCCL
///         an uninitialized pciPath pointer, because pluginPciPath() leaves
///         its output untouched on failure.
ncclResult_t pluginGetProperties(int dev, ncclNetProperties_v8_t* props) {
  auto factory_dev = RDMAFactory::get_factory_dev(dev);
  props->name = factory_dev->ib_name;

  // Speed in *Mbps*. 100000 means 100G
  props->speed = factory_dev->link_bw * 8 / 1e6;

  ncclResult_t pci_ret = pluginPciPath(factory_dev->ib_name, &props->pciPath);
  if (pci_ret != ncclSuccess) return pci_ret;

  // Only used to detect NICs with multiple PCI attachments.
  props->guid = factory_dev->dev_attr.sys_image_guid;

  props->ptrSupport = NCCL_PTR_HOST;

  // TODO: make this configurable.
  if (ncclIbGdrSupport() == ncclSuccess) {
    props->ptrSupport |= NCCL_PTR_CUDA;
  }

  if (factory_dev->dma_buf_support) props->ptrSupport |= NCCL_PTR_DMABUF;

  if (props->ptrSupport == NCCL_PTR_HOST) {
    DCHECK(0) << "Lack of GPU Direct RDMA support.";
  }

  // If you regMr has a fast registration cache, set to 1. If set to 0, user
  // buffer registration may be disabled.
  props->regIsGlobal = 0;

  // Port number, used in conjunction with guid
  props->port = factory_dev->ib_port_num;
  // Custom latency (used to help tuning if latency is high. If set to 0, use
  // default NCCL values.
  props->latency = 0;
  // Maximum number of comm objects we can create.
  props->maxComms = 1024 * 1024;
  // Maximum number of receive operations taken by irecv().
  props->maxRecvs = 1;
  // Coupling with NCCL network device-side code.
  props->netDeviceType = NCCL_NET_DEVICE_HOST;
  props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION;
  return ncclSuccess;
}

// To create a connection, NCCL will start by calling listen on the receiver
// side. This function takes a device number as input argument, and should
// return a local listenComm object, and a handle to pass to the other side, so
// that the sender side can connect to the receiver. The handle is a buffer of
// size NCCL_NET_HANDLE_MAXSIZE and is provided by NCCL. This call should never
// block, but contrary to connect and accept, listenComm should never be NULL if
// the call succeeds.
//
// Every call with a side effect below is made outside DCHECK(). glog's DCHECK
// does not evaluate its condition in NDEBUG builds, so a syscall placed inside
// it would silently be skipped in release builds. Each failing step closes the
// socket and returns ncclInternalError, matching the existing bind() path.
ncclResult_t pluginListen(int dev, void* opaqueHandle, void** listenComm) {
  int ret = 0;
  struct ucclHandle* handle = (struct ucclHandle*)opaqueHandle;
  memset(handle, 0, sizeof(struct ucclHandle));

#ifdef LAZY_CREATE_ENGINE
  ep->initialize_engine_by_dev(dev, false);
#endif

  // Create a listening socket.
  int listen_fd = socket(AF_INET, SOCK_STREAM, 0);
  if (listen_fd < 0) {
    LOG(ERROR) << "ERROR: opening socket, dev: " << dev;
    return ncclInternalError;
  }

  // The port below is OS-assigned, so a failure here is only asserted in
  // debug builds rather than treated as fatal.
  int flag = 1;
  ret = setsockopt(listen_fd, SOL_SOCKET, SO_REUSEADDR, &flag, sizeof(int));
  DCHECK(ret >= 0) << ret;

  struct sockaddr_in serv_addr;
  bzero((char*)&serv_addr, sizeof(serv_addr));
  serv_addr.sin_family = AF_INET;
  serv_addr.sin_addr.s_addr = INADDR_ANY;
  serv_addr.sin_port = 0;  // Let OS assign a free ephemeral port
  ret = bind(listen_fd, (struct sockaddr*)&serv_addr, sizeof(serv_addr));
  if (ret < 0) {
    LOG(ERROR) << "ERROR: binding socket, ret: " << ret
               << ", port: " << ntohs(serv_addr.sin_port) << ", dev: " << dev;
    close(listen_fd);
    return ncclInternalError;
  }

  // Get the actual port assigned by the OS.
  socklen_t len = sizeof(serv_addr);
  ret = getsockname(listen_fd, (struct sockaddr*)&serv_addr, &len);
  if (ret < 0) {
    LOG(ERROR) << "ERROR: getsockname, ret: " << ret << ", dev: " << dev;
    close(listen_fd);
    return ncclInternalError;
  }

  ret = listen(listen_fd, 1);
  if (ret < 0) {
    LOG(ERROR) << "ERROR: listening on socket, ret: " << ret
               << ", dev: " << dev;
    close(listen_fd);
    return ncclInternalError;
  }

  // Fill out handle which will be passed to the other side.
  auto factory_dev = RDMAFactory::get_factory_dev(dev);
  handle->ip_addr_u32 = str_to_ip(factory_dev->local_ip_str);
  handle->listen_port = ntohs(serv_addr.sin_port);
  handle->remote_dev = dev;
#ifndef __HIP_PLATFORM_AMD__
  cudaGetDevice(&handle->remote_gpuidx);
#else
  hipError_t hip_ret = hipGetDevice(&handle->remote_gpuidx);
  DCHECK(hip_ret == hipSuccess);
#endif

  struct ucclListenComm* lcomm =
      (struct ucclListenComm*)calloc(1, sizeof(struct ucclListenComm));
  if (lcomm == nullptr) {
    LOG(ERROR) << "ERROR: allocating listen comm, dev: " << dev;
    close(listen_fd);
    return ncclInternalError;
  }

  lcomm->dev = dev;
  lcomm->listen_fd = listen_fd;
  lcomm->gpuidx = handle->remote_gpuidx;

  *listenComm = lcomm;

  UCCL_LOG_PLUGIN << "Listen on dev: " << dev << " from PID: " << getpid();

  return ncclSuccess;
}

// NCCL will use its bootstrap infrastructure to provide the handle to the
// sender side, then call connect on the sender side on a given device index
// dev, providing the handle. connect should not block either, and instead set
// sendComm to NULL and return ncclSuccess. In that case, NCCL will call connect
// again until it succeeds.
//
// The first call spawns a detached helper thread. That thread runs the blocking
// ep->uccl_connect() and records the result in handle->connect_buffer. Later
// calls poll handle->state and build the ucclSendComm once it reads
// kConnConnected.
ncclResult_t pluginConnect(int dev, void* opaque_handle, void** sendComm,
                           ncclNetDeviceHandle_v8_t** /*sendDevComm*/) {
  struct ucclHandle* handle = (struct ucclHandle*)opaque_handle;
  int local_gpuidx = 0;
#ifndef __HIP_PLATFORM_AMD__
  cudaGetDevice(&local_gpuidx);
#else
  // Call outside DCHECK(): its condition is not evaluated in NDEBUG builds,
  // which would leave local_gpuidx unset.
  hipError_t hip_ret = hipGetDevice(&local_gpuidx);
  DCHECK(hip_ret == hipSuccess);
#endif

  std::string remote_ip_str = ip_to_str(handle->ip_addr_u32);

  // Not connected unless the kConnConnected path below succeeds.
  *sendComm = nullptr;

  // Read the state once; the helper thread may update it concurrently.
  enum ConnState state = handle->state;
  if (state == kConnInit) {
    handle->state = kConnConnecting;
    // Delegate the blocking connection to another thread.
    std::thread t = std::thread([dev, local_gpuidx, handle, remote_ip_str] {
      handle->connect_buffer.base.conn_id = ep->uccl_connect(
          dev, local_gpuidx, handle->remote_dev, handle->remote_gpuidx,
          remote_ip_str, handle->listen_port);
      handle->connect_buffer.base.dev = dev;
      // Publish connect_buffer before flipping the state.
      std::atomic_thread_fence(std::memory_order_release);
      handle->state = kConnConnected;
    });
    t.detach();
    return ncclSuccess;
  }
  if (state == kConnConnecting) return ncclSuccess;

  DCHECK(state == kConnConnected);
  // Intended to pair with the release fence in the helper thread so that
  // connect_buffer is observed fully written. NOTE: `state` is a plain field,
  // so this ordering is best-effort rather than a formal happens-before.
  std::atomic_thread_fence(std::memory_order_acquire);

  // Allocate only once connected, so polling calls do not churn the
  // allocator. calloc() is kept to match the free() in pluginCloseSend().
  struct ucclSendComm* scomm =
      (struct ucclSendComm*)calloc(1, sizeof(struct ucclSendComm));
  if (scomm == nullptr) return ncclInternalError;
  scomm->base = handle->connect_buffer.base;
  scomm->base.uccl_req_pool = std::make_shared<ucclRequestBuffPool>();
  *sendComm = scomm;

  UCCL_LOG_PLUGIN << "Connected to " << remote_ip_str << "/"
                  << handle->remote_dev << " on dev:" << dev << ", "
                  << scomm->base.conn_id.flow_id << " from PID: " << getpid();

  return ncclSuccess;
}

// To finalize the connection, the receiver side will call accept on the
// listenComm returned by the listen call previously. If the sender did not
// connect yet, accept should not block. It should return ncclSuccess, setting
// recvComm to NULL. NCCL will call accept again until it succeeds.
//
// The first call spawns a detached helper thread. That thread runs the blocking
// ep->uccl_accept() on lcomm->listen_fd and records the result in
// lcomm->accept_buffer. Later calls poll lcomm->state and build the
// ucclRecvComm once it reads kConnConnected.
ncclResult_t pluginAccept(void* listenComm, void** recvComm,
                          ncclNetDeviceHandle_v8_t** /*recvDevComm*/) {
  struct ucclListenComm* lcomm = (struct ucclListenComm*)listenComm;

  // Not accepted unless the kConnConnected path below succeeds.
  *recvComm = nullptr;

  // Read the state once; the helper thread may update it concurrently.
  enum ConnState state = lcomm->state;
  if (state == kConnInit) {
    lcomm->state = kConnConnecting;
    // Delegate the blocking accept to another thread.
    std::thread t = std::thread([lcomm] {
      std::string remote_ip_str;
      int remote_dev = -1;
      lcomm->accept_buffer.base.conn_id =
          ep->uccl_accept(lcomm->dev, lcomm->listen_fd, lcomm->gpuidx,
                          remote_ip_str, &remote_dev);
      lcomm->accept_buffer.base.dev = lcomm->dev;
      lcomm->accept_buffer.remote_ip_str = remote_ip_str;
      lcomm->accept_buffer.remote_dev = remote_dev;
      // Ensure kConnConnected is set after all other fields are set.
      std::atomic_thread_fence(std::memory_order_release);
      lcomm->state = kConnConnected;
    });
    t.detach();
    return ncclSuccess;
  }
  if (state == kConnConnecting) return ncclSuccess;

  DCHECK(state == kConnConnected);
  // Intended to pair with the release fence in the helper thread so that
  // accept_buffer is observed fully written. NOTE: `state` is a plain field,
  // so this ordering is best-effort rather than a formal happens-before.
  std::atomic_thread_fence(std::memory_order_acquire);

  // Allocate only once accepted, so polling calls do not churn the allocator.
  // calloc() is kept to match the free() in pluginCloseRecv().
  struct ucclRecvComm* rcomm =
      (struct ucclRecvComm*)calloc(1, sizeof(struct ucclRecvComm));
  if (rcomm == nullptr) return ncclInternalError;
  rcomm->base = lcomm->accept_buffer.base;
  rcomm->base.uccl_req_pool = std::make_shared<ucclRequestBuffPool>();
  rcomm->remote_ip_str = lcomm->accept_buffer.remote_ip_str;
  rcomm->remote_dev = lcomm->accept_buffer.remote_dev;
  *recvComm = rcomm;

  UCCL_LOG_PLUGIN << "Accepted from " << rcomm->remote_ip_str << "/"
                  << rcomm->remote_dev << " on dev:" << lcomm->dev << ", "
                  << rcomm->base.conn_id.flow_id << " from PID: " << getpid();

  return ncclSuccess;
}

/// Registers `size` bytes at `data` with the endpoint for the comm's flow.
///
/// @param collComm A send or recv comm. Its leading ucclBaseComm supplies the
///                 flow (conn_id.context).
/// @param type     Memory type, forwarded unchanged to ep->uccl_regmr().
/// @param mhandle  Output: receives the Mhandle* produced by the endpoint.
/// @return ncclSuccess if ep->uccl_regmr() returns 0, else ncclInternalError.
ncclResult_t pluginRegMr(void* collComm, void* data, size_t size, int type,
                         void** mhandle) {
  auto* base = static_cast<struct ucclBaseComm*>(collComm);
  auto* flow = (UcclFlow*)base->conn_id.context;
  int const rc = ep->uccl_regmr(flow, data, size, type,
                                reinterpret_cast<struct Mhandle**>(mhandle));

  UCCL_LOG_PLUGIN << "RegMr, " << size << ", " << base->conn_id.flow_id
                  << " from PID: " << getpid();

  if (rc != 0) return ncclInternalError;
  return ncclSuccess;
}

/// Registers a DMA-BUF backed region (`fd` at `offset`) covering `size` bytes
/// at `data` for the comm's flow.
///
/// @param collComm A send or recv comm. Its leading ucclBaseComm supplies the
///                 flow (conn_id.context).
/// @param mhandle  Output: receives the Mhandle* produced by the endpoint.
/// @return ncclSuccess if ep->uccl_regmr_dmabuf() returns 0, else
///         ncclInternalError.
ncclResult_t pluginRegMrDmaBuf(void* collComm, void* data, size_t size,
                               int type, uint64_t offset, int fd,
                               void** mhandle) {
  auto* base = static_cast<struct ucclBaseComm*>(collComm);
  auto* flow = (UcclFlow*)base->conn_id.context;
  auto** out_handle = reinterpret_cast<struct Mhandle**>(mhandle);
  int const rc =
      ep->uccl_regmr_dmabuf(flow, data, size, type, offset, fd, out_handle);

  UCCL_LOG_PLUGIN << "RegMrDmaBuf, " << size << ", " << base->conn_id.flow_id
                  << " from PID: " << getpid();

  if (rc != 0) return ncclInternalError;
  return ncclSuccess;
}

/// Deregisters a memory region previously returned by pluginRegMr() or
/// pluginRegMrDmaBuf().
///
/// @param collComm Unused: deregistration only needs the handle.
/// @param mhandle  The Mhandle* to release via ep->uccl_deregmr().
/// @return Always ncclSuccess.
ncclResult_t pluginDeregMr(void* collComm, void* mhandle) {
  (void)collComm;
  ep->uccl_deregmr((struct Mhandle*)mhandle);
  return ncclSuccess;
}

/// NCCL isend: posts a send of `size` bytes at `data` on `sendComm`.
///
/// A request slot is taken from the comm's pool and handed to
/// ep->uccl_send_async(). If no slot is free, or the post returns nonzero,
/// `*request` is set to nullptr and any slot taken is returned to the pool.
/// Otherwise `*request` is the in-flight request; pluginTest() completes it.
///
/// @param tag Unused by this plugin.
/// @return Always ncclSuccess.
ncclResult_t pluginIsend(void* sendComm, void* data, int size, int tag,
                         void* mhandle, void** request) {
  auto* scomm = static_cast<struct ucclSendComm*>(sendComm);
  ucclRequestBuffPool* pool = scomm->base.uccl_req_pool.get();

  uint64_t slot_addr;
  if (pool->alloc_buff(&slot_addr)) {
    // No free request slot: nothing is posted.
    *request = nullptr;
    return ncclSuccess;
  }

  auto* req = reinterpret_cast<struct ucclRequest*>(slot_addr);
  auto* flow = (UcclFlow*)scomm->base.conn_id.context;
  auto* mh = static_cast<struct Mhandle*>(mhandle);
  if (ep->uccl_send_async(flow, mh, data, size, req)) {
    // Post failed: give the slot back.
    pool->free_buff(reinterpret_cast<uint64_t>(req));
    *request = nullptr;
    return ncclSuccess;
  }

  // pluginTest() uses this back-pointer to release the slot on completion.
  req->req_pool = static_cast<void*>(pool);
  *request = req;

  UCCL_LOG_PLUGIN << "Isend on dev: " << scomm->base.dev << ", " << size
                  << "B, ureq ptr:" << req << " from PID: " << getpid();

  return ncclSuccess;
}

/// NCCL irecv: posts a receive of `n` buffers on `recvComm`.
///
/// `data`, `sizes` and `mhandles` are per-buffer arrays with `n` entries,
/// forwarded to ep->uccl_recv_async(). If no request slot is free, or the post
/// returns nonzero, `*request` is set to nullptr and any slot taken is returned
/// to the pool. Otherwise `*request` is the in-flight request; pluginTest()
/// completes it.
///
/// @param tags Unused by this plugin.
/// @return Always ncclSuccess.
ncclResult_t pluginIrecv(void* recvComm, int n, void** data, int* sizes,
                         int* tags, void** mhandles, void** request) {
  auto* rcomm = static_cast<struct ucclRecvComm*>(recvComm);
  ucclRequestBuffPool* pool = rcomm->base.uccl_req_pool.get();

  uint64_t slot_addr;
  if (pool->alloc_buff(&slot_addr)) {
    // No free request slot: nothing is posted.
    *request = nullptr;
    return ncclSuccess;
  }

  auto* req = reinterpret_cast<struct ucclRequest*>(slot_addr);
  auto* flow = (UcclFlow*)rcomm->base.conn_id.context;
  auto** mhs = reinterpret_cast<struct Mhandle**>(mhandles);
  if (ep->uccl_recv_async(flow, mhs, data, sizes, n, req)) {
    // Post failed: give the slot back.
    pool->free_buff(reinterpret_cast<uint64_t>(req));
    *request = nullptr;
    return ncclSuccess;
  }

  // pluginTest() uses this back-pointer to release the slot on completion.
  req->req_pool = static_cast<void*>(pool);
  *request = req;

  UCCL_LOG_PLUGIN << "Irecv on dev: " << rcomm->base.dev << ", " << sizes[0]
                  << "B, ureq ptr:" << req << " from PID: " << getpid();

  return ncclSuccess;
}

/// NCCL iflush: posts a flush of `n` received buffers on `recvComm`.
///
/// A request slot is taken from the comm's pool and handed to ep->uccl_flush().
/// If no slot is free, or uccl_flush() returns nonzero, `*request` is set to
/// nullptr and any slot taken is returned to the pool. Otherwise `*request` is
/// the in-flight request; pluginTest() completes it.
///
/// @return Always ncclSuccess.
ncclResult_t pluginIflush(void* recvComm, int n, void** data, int* sizes,
                          void** mhandles, void** request) {
  struct ucclRecvComm* rcomm = (struct ucclRecvComm*)recvComm;
  auto conn_id = rcomm->base.conn_id;
  struct Mhandle** mhs = (struct Mhandle**)mhandles;

  uint64_t addr;
  if (rcomm->base.uccl_req_pool->alloc_buff(&addr)) {
    *request = nullptr;
    return ncclSuccess;
  }
  struct ucclRequest* req = reinterpret_cast<struct ucclRequest*>(addr);

  if (ep->uccl_flush((UcclFlow*)conn_id.context, mhs, data, sizes, n, req)) {
    rcomm->base.uccl_req_pool->free_buff(reinterpret_cast<uint64_t>(req));
    *request = nullptr;
    return ncclSuccess;
  }

  // pluginTest() uses this back-pointer to release the slot on completion.
  req->req_pool = (void*)rcomm->base.uccl_req_pool.get();

  *request = req;

  return ncclSuccess;
}

/// Polls a request returned by pluginIsend/pluginIrecv/pluginIflush once.
///
/// @param request Request previously returned through `*request`.
/// @param done    Set to 1 if the request completed, else 0.
/// @param size    Optional; may be NULL, and is only written when non-NULL.
///                NCCL's own IB transport null-checks this argument too. On
///                send completion, size[0] receives the sent byte count. On
///                receive completion, size[i] receives the byte count of the
///                i-th of req->n buffers. Not written for flushes.
/// @return Always ncclSuccess. A completed request's slot is returned to the
///         pool it came from, so `request` must not be used afterwards.
ncclResult_t pluginTest(void* request, int* done, int* size) {
  struct ucclRequest* req = reinterpret_cast<struct ucclRequest*>(request);

  if (!ep->uccl_poll_ureq_once(req)) {
    *done = 0;
    return ncclSuccess;
  }

  *done = 1;
  if (req->type == ReqTx || req->type == ReqTxRC) {
    if (size) size[0] = req->send.data_len;
    UCCL_LOG_PLUGIN << "Test Tx done, " << req->send.data_len
                    << "B, ureq ptr:" << req;
  } else if (req->type == ReqRx || req->type == ReqRxRC) {
    if (size) {
      for (int i = 0; i < req->n; i++) size[i] = req->recv.data_len[i];
    }
    UCCL_LOG_PLUGIN << "Test Rx done, " << req->recv.data_len[0]
                    << "B, ureq ptr:" << req << ", req->type:" << req->type;
  }
  // ReqFlush (and any other type): nothing to report.

  // Return the slot to the pool recorded by isend/irecv/iflush.
  auto uccl_req_pool = reinterpret_cast<ucclRequestBuffPool*>(req->req_pool);
  uccl_req_pool->free_buff(reinterpret_cast<uint64_t>(req));

  return ncclSuccess;
}

/// Releases a ucclSendComm produced by pluginConnect().
///
/// pluginConnect() calloc()s the comm and then assigns its C++ members (the
/// shared request pool) in place. free() alone would skip their destructors and
/// leak the pool, so the destructor is run explicitly first.
/// Assumes no request from this comm is still outstanding, because
/// pluginTest() returns slots through a raw pointer to the pool.
/// NOTE(review): consider moving pluginConnect/pluginCloseSend to new/delete.
ncclResult_t pluginCloseSend(void* sendComm) {
  struct ucclSendComm* scomm = (struct ucclSendComm*)sendComm;
  if (scomm == nullptr) return ncclSuccess;
  scomm->~ucclSendComm();
  free(scomm);
  return ncclSuccess;
}
/// Releases a ucclRecvComm produced by pluginAccept().
///
/// pluginAccept() calloc()s the comm and then assigns its C++ members (the
/// shared request pool and the remote IP string) in place. free() alone would
/// skip their destructors and leak them, so the destructor is run explicitly
/// first.
/// Assumes no request from this comm is still outstanding, because
/// pluginTest() returns slots through a raw pointer to the pool.
/// NOTE(review): consider moving pluginAccept/pluginCloseRecv to new/delete.
ncclResult_t pluginCloseRecv(void* recvComm) {
  struct ucclRecvComm* rcomm = (struct ucclRecvComm*)recvComm;
  if (rcomm == nullptr) return ncclSuccess;
  rcomm->~ucclRecvComm();
  free(rcomm);
  return ncclSuccess;
}
/// Closes the listening socket and releases a ucclListenComm produced by
/// pluginListen().
///
/// The accept helper thread assigns accept_buffer's C++ members (including
/// the remote IP string) in place on the calloc()'ed storage. The destructor
/// is therefore run explicitly before free() so those members are not leaked.
/// NOTE(review): if the accept thread spawned by pluginAccept() is still
/// running (state == kConnConnecting), it keeps using `comm` after this call.
ncclResult_t pluginCloseListen(void* listenComm) {
  struct ucclListenComm* comm = (struct ucclListenComm*)listenComm;
  if (comm == nullptr) return ncclSuccess;
  close(comm->listen_fd);
  comm->~ucclListenComm();
  free(comm);
  return ncclSuccess;
}

// Net-plugin descriptor (ncclNet_v8_t) mapping each NCCL network callback to
// its implementation in this file. NOTE(review): NCCL presumably locates this
// table by the exported symbol name `ncclNetPlugin_v8`, so keep the name.
// getDeviceMr and irecvConsumed are left null. pluginGetProperties() reports
// netDeviceType = NCCL_NET_DEVICE_HOST.
ncclNet_v8_t volatile ncclNetPlugin_v8 = {
    .name = PLUGIN_NAME,
    .init = pluginInit,
    .devices = pluginDevices,
    .getProperties = pluginGetProperties,
    .listen = pluginListen,
    .connect = pluginConnect,
    .accept = pluginAccept,
    .regMr = pluginRegMr,
    .regMrDmaBuf = pluginRegMrDmaBuf,
    .deregMr = pluginDeregMr,
    .isend = pluginIsend,
    .irecv = pluginIrecv,
    .iflush = pluginIflush,
    .test = pluginTest,
    .closeSend = pluginCloseSend,
    .closeRecv = pluginCloseRecv,
    .closeListen = pluginCloseListen,
    .getDeviceMr = nullptr,
    .irecvConsumed = nullptr,
};
